{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import collections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(words, n_words):\n",
    "    count = [['GO', 0], ['PAD', 1], ['EOS', 2], ['UNK', 3]]\n",
    "    count.extend(collections.Counter(words).most_common(n_words - 1))\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    data = list()\n",
    "    unk_count = 0\n",
    "    for word in words:\n",
    "        index = dictionary.get(word, 0)\n",
    "        if index == 0:\n",
    "            unk_count += 1\n",
    "        data.append(index)\n",
    "    count[0][1] = unk_count\n",
    "    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "    return data, count, dictionary, reversed_dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len from: 265, len to: 265\n"
     ]
    }
   ],
   "source": [
    "# Load both text files, lower-case them and split them on newlines.\n",
    "with open('from.txt', 'r') as source_file:\n",
    "    text_from = source_file.read().lower().split('\\n')\n",
    "with open('to.txt', 'r') as target_file:\n",
    "    text_to = target_file.read().lower().split('\\n')\n",
    "line_counts = (len(text_from), len(text_to))\n",
    "print('len from: %d, len to: %d'%line_counts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 331\n",
      "Most common words [('you', 73), ('is', 67), ('what', 63), ('a', 49), ('the', 40), ('do', 36)]\n",
      "Sample data [104, 62, 259, 62, 173, 62, 292, 62, 211, 11] ['hi', 'good', 'morning', 'good', 'afternoon', 'good', 'evening', 'good', 'night', 'how']\n"
     ]
    }
   ],
   "source": [
    "# Flatten the source sentences into one token stream and index it.\n",
    "concat_from = ' '.join(text_from).split()\n",
    "vocabulary_size_from = len(set(concat_from))\n",
    "data_from, count_from, dictionary_from, rev_dictionary_from = build_dataset(\n",
    "    concat_from, vocabulary_size_from)\n",
    "print('vocab from size: %d'%(vocabulary_size_from))\n",
    "# The first four entries of count_from are the reserved tokens, so skip them.\n",
    "print('Most common words', count_from[4:10])\n",
    "print('Sample data', data_from[:10],\n",
    "      [rev_dictionary_from[token_id] for token_id in data_from[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab to size: 504\n",
      "Most common words [('i', 127), ('you', 55), ('a', 47), ('to', 44), ('the', 40), ('it', 38)]\n",
      "Sample data [186, 107, 35, 228, 35, 385, 35, 332, 35, 424] ['hi', 'there', 'good', 'morning', 'good', 'afternoon', 'good', 'evening', 'good', 'night']\n"
     ]
    }
   ],
   "source": [
    "# Flatten the target sentences into one token stream and index it.\n",
    "concat_to = ' '.join(text_to).split()\n",
    "vocabulary_size_to = len(set(concat_to))\n",
    "data_to, count_to, dictionary_to, rev_dictionary_to = build_dataset(\n",
    "    concat_to, vocabulary_size_to)\n",
    "print('vocab to size: %d'%(vocabulary_size_to))\n",
    "# The first four entries of count_to are the reserved tokens, so skip them.\n",
    "print('Most common words', count_to[4:10])\n",
    "print('Sample data', data_to[:10],\n",
    "      [rev_dictionary_to[token_id] for token_id in data_to[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ids of the reserved tokens, looked up in the source-side dictionary.\n",
    "GO, PAD, EOS, UNK = (dictionary_from[token]\n",
    "                     for token in ('GO', 'PAD', 'EOS', 'UNK'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Chatbot:\n",
    "    \"\"\"Seq2seq model: stacked bidirectional RNN encoder, multi-layer RNN decoder.\n",
    "\n",
    "    The whole TensorFlow graph is built in the constructor. Feed `X`, `Y`,\n",
    "    `X_seq_len` and `Y_seq_len`; fetch `training_logits`, `predicting_ids`,\n",
    "    `cost` and `optimizer`. The module-level `GO` and `EOS` token ids are used\n",
    "    as the decoder start and end tokens.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, size_layer, num_layers, embedded_size,\n",
    "                 from_dict_size, to_dict_size, learning_rate,\n",
    "                 batch_size, dropout = 0.5, beam_width = 15):\n",
    "        \"\"\"Build the encoder, both decoders, the loss and the train op.\n",
    "\n",
    "        Args:\n",
    "            size_layer: width of every decoder cell; each encoder direction\n",
    "                uses `size_layer // 2` units and the two are concatenated.\n",
    "            num_layers: number of encoder layers and of decoder cells.\n",
    "            embedded_size: width of both embedding tables.\n",
    "            from_dict_size: number of rows of the encoder embedding table.\n",
    "            to_dict_size: number of rows of the decoder embedding table and\n",
    "                width of the output projection.\n",
    "            learning_rate: learning rate handed to the Adam optimizer.\n",
    "            batch_size: kept for backward compatibility only; the batch\n",
    "                dimension is read from `X` when the graph runs.\n",
    "            dropout: accepted but not used by this graph.\n",
    "            beam_width: accepted but not used by this graph.\n",
    "        \"\"\"\n",
    "\n",
    "        def rnn_cell(size, reuse=False):\n",
    "            # Plain recurrent cell (BasicRNNCell), not an LSTM.\n",
    "            return tf.nn.rnn_cell.BasicRNNCell(size, reuse=reuse)\n",
    "\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.X_seq_len = tf.placeholder(tf.int32, [None])\n",
    "        self.Y_seq_len = tf.placeholder(tf.int32, [None])\n",
    "        # Use the number of rows actually fed instead of the constructor\n",
    "        # constant, so batches of any size (e.g. one sentence) are accepted.\n",
    "        dynamic_batch_size = tf.shape(self.X)[0]\n",
    "        # encoder\n",
    "        encoder_embeddings = tf.Variable(tf.random_uniform([from_dict_size, embedded_size], -1, 1))\n",
    "        encoder_embedded = tf.nn.embedding_lookup(encoder_embeddings, self.X)\n",
    "        for n in range(num_layers):\n",
    "            (out_fw, out_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(\n",
    "                cell_fw = rnn_cell(size_layer // 2),\n",
    "                cell_bw = rnn_cell(size_layer // 2),\n",
    "                inputs = encoder_embedded,\n",
    "                sequence_length = self.X_seq_len,\n",
    "                dtype = tf.float32,\n",
    "                scope = 'bidirectional_rnn_%d'%(n))\n",
    "            # The concatenated outputs of this layer feed the next layer.\n",
    "            encoder_embedded = tf.concat((out_fw, out_bw), 2)\n",
    "\n",
    "        # Final forward/backward states of the last encoder layer, replicated\n",
    "        # once per decoder cell as the decoder's initial state.\n",
    "        bi_state = tf.concat((state_fw, state_bw), -1)\n",
    "        self.encoder_state = tuple([bi_state] * num_layers)\n",
    "\n",
    "        # Decoder input for teacher forcing: drop the last column of Y and\n",
    "        # prepend a GO column.\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [dynamic_batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([dynamic_batch_size, 1], GO), main], 1)\n",
    "        # decoder\n",
    "        decoder_embeddings = tf.Variable(tf.random_uniform([to_dict_size, embedded_size], -1, 1))\n",
    "        decoder_cells = tf.nn.rnn_cell.MultiRNNCell([rnn_cell(size_layer) for _ in range(num_layers)])\n",
    "        dense_layer = tf.layers.Dense(to_dict_size)\n",
    "        training_helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                inputs = tf.nn.embedding_lookup(decoder_embeddings, decoder_input),\n",
    "                sequence_length = self.Y_seq_len,\n",
    "                time_major = False)\n",
    "        training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = training_helper,\n",
    "                initial_state = self.encoder_state,\n",
    "                output_layer = dense_layer)\n",
    "        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = training_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "        # Greedy decoder; it is given the same cell and projection objects\n",
    "        # as the training decoder.\n",
    "        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(\n",
    "                embedding = decoder_embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [dynamic_batch_size]),\n",
    "                end_token = EOS)\n",
    "        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = predicting_helper,\n",
    "                initial_state = self.encoder_state,\n",
    "                output_layer = dense_layer)\n",
    "        # Generation is capped at twice the longest source sequence.\n",
    "        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))\n",
    "        self.training_logits = training_decoder_output.rnn_output\n",
    "        self.predicting_ids = predicting_decoder_output.sample_id\n",
    "        # Mask built from Y_seq_len so padded target positions get zero weight.\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 128\n",
    "num_layers = 2\n",
    "embedded_size = 128\n",
    "learning_rate = 0.001\n",
    "batch_size = 32\n",
    "epoch = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from an empty graph, then build the model and initialise its variables.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Chatbot(\n",
    "    size_layer = size_layer,\n",
    "    num_layers = num_layers,\n",
    "    embedded_size = embedded_size,\n",
    "    # + 4 leaves room for the reserved GO / PAD / EOS / UNK ids\n",
    "    from_dict_size = vocabulary_size_from + 4,\n",
    "    to_dict_size = vocabulary_size_to + 4,\n",
    "    learning_rate = learning_rate,\n",
    "    batch_size = batch_size)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_idx(corpus, dic):\n",
    "    X = []\n",
    "    for i in corpus:\n",
    "        ints = []\n",
    "        for k in i.split():\n",
    "            try:\n",
    "                ints.append(dic[k])\n",
    "            except Exception as e:\n",
    "                print(e)\n",
    "                ints.append(2)\n",
    "        X.append(ints)\n",
    "    return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'false'\n",
      "'genius'\n"
     ]
    }
   ],
   "source": [
    "# Encode every source / target sentence as a list of vocabulary ids.\n",
    "X, Y = str_idx(text_from, dictionary_from), str_idx(text_to, dictionary_to)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    padded_seqs = []\n",
    "    seq_lens = []\n",
    "    max_sentence_len = max([len(sentence) for sentence in sentence_batch])\n",
    "    for sentence in sentence_batch:\n",
    "        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))\n",
    "        seq_lens.append(len(sentence))\n",
    "    return padded_seqs, seq_lens\n",
    "\n",
    "def check_accuracy(logits, Y):\n",
    "    acc = 0\n",
    "    for i in range(logits.shape[0]):\n",
    "        internal_acc = 0\n",
    "        for k in range(len(Y[i])):\n",
    "            try:\n",
    "                if Y[i][k] == logits[i][k]:\n",
    "                    internal_acc += 1\n",
    "            except:\n",
    "                continue\n",
    "        acc += (internal_acc / len(Y[i]))\n",
    "    return acc / logits.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, avg loss: 6.177050, avg accuracy: 0.006740\n",
      "epoch: 2, avg loss: 5.703311, avg accuracy: 0.030387\n",
      "epoch: 3, avg loss: 5.279043, avg accuracy: 0.032690\n",
      "epoch: 4, avg loss: 4.952622, avg accuracy: 0.032662\n",
      "epoch: 5, avg loss: 4.675222, avg accuracy: 0.028388\n",
      "epoch: 6, avg loss: 4.440371, avg accuracy: 0.039776\n",
      "epoch: 7, avg loss: 4.223120, avg accuracy: 0.041956\n",
      "epoch: 8, avg loss: 4.011665, avg accuracy: 0.044920\n",
      "epoch: 9, avg loss: 3.800355, avg accuracy: 0.046654\n",
      "epoch: 10, avg loss: 3.587758, avg accuracy: 0.052607\n",
      "epoch: 11, avg loss: 3.373784, avg accuracy: 0.059839\n",
      "epoch: 12, avg loss: 3.159247, avg accuracy: 0.068065\n",
      "epoch: 13, avg loss: 2.945502, avg accuracy: 0.080093\n",
      "epoch: 14, avg loss: 2.733823, avg accuracy: 0.092254\n",
      "epoch: 15, avg loss: 2.525805, avg accuracy: 0.105387\n",
      "epoch: 16, avg loss: 2.325343, avg accuracy: 0.122158\n",
      "epoch: 17, avg loss: 2.138892, avg accuracy: 0.138479\n",
      "epoch: 18, avg loss: 2.004267, avg accuracy: 0.158598\n",
      "epoch: 19, avg loss: 1.929683, avg accuracy: 0.179286\n",
      "epoch: 20, avg loss: 1.703825, avg accuracy: 0.167576\n",
      "epoch: 21, avg loss: 1.528182, avg accuracy: 0.201227\n",
      "epoch: 22, avg loss: 1.370113, avg accuracy: 0.224063\n",
      "epoch: 23, avg loss: 1.229174, avg accuracy: 0.252597\n",
      "epoch: 24, avg loss: 1.101900, avg accuracy: 0.269698\n",
      "epoch: 25, avg loss: 0.987202, avg accuracy: 0.292782\n",
      "epoch: 26, avg loss: 0.885681, avg accuracy: 0.302628\n",
      "epoch: 27, avg loss: 0.794274, avg accuracy: 0.326227\n",
      "epoch: 28, avg loss: 0.711147, avg accuracy: 0.336566\n",
      "epoch: 29, avg loss: 0.636668, avg accuracy: 0.349356\n",
      "epoch: 30, avg loss: 0.569845, avg accuracy: 0.362920\n",
      "epoch: 31, avg loss: 0.512654, avg accuracy: 0.373077\n",
      "epoch: 32, avg loss: 0.463061, avg accuracy: 0.375649\n",
      "epoch: 33, avg loss: 0.416303, avg accuracy: 0.381827\n",
      "epoch: 34, avg loss: 0.373228, avg accuracy: 0.385233\n",
      "epoch: 35, avg loss: 0.336349, avg accuracy: 0.385920\n",
      "epoch: 36, avg loss: 0.304561, avg accuracy: 0.388714\n",
      "epoch: 37, avg loss: 0.277012, avg accuracy: 0.392505\n",
      "epoch: 38, avg loss: 0.253260, avg accuracy: 0.393001\n",
      "epoch: 39, avg loss: 0.232312, avg accuracy: 0.394316\n",
      "epoch: 40, avg loss: 0.213747, avg accuracy: 0.393231\n",
      "epoch: 41, avg loss: 0.196835, avg accuracy: 0.394316\n",
      "epoch: 42, avg loss: 0.181782, avg accuracy: 0.394316\n",
      "epoch: 43, avg loss: 0.168571, avg accuracy: 0.394316\n",
      "epoch: 44, avg loss: 0.156900, avg accuracy: 0.394316\n",
      "epoch: 45, avg loss: 0.146453, avg accuracy: 0.394316\n",
      "epoch: 46, avg loss: 0.137037, avg accuracy: 0.394316\n",
      "epoch: 47, avg loss: 0.128580, avg accuracy: 0.394316\n",
      "epoch: 48, avg loss: 0.120896, avg accuracy: 0.394316\n",
      "epoch: 49, avg loss: 0.113954, avg accuracy: 0.394316\n",
      "epoch: 50, avg loss: 0.107593, avg accuracy: 0.394316\n"
     ]
    }
   ],
   "source": [
    "# Only full batches are used; a trailing partial batch is skipped.\n",
    "num_batches = len(text_from) // batch_size\n",
    "for epoch_idx in range(epoch):\n",
    "    total_loss, total_accuracy = 0, 0\n",
    "    for batch_idx in range(num_batches):\n",
    "        start = batch_idx * batch_size\n",
    "        stop = start + batch_size\n",
    "        batch_x, seq_x = pad_sentence_batch(X[start: stop], PAD)\n",
    "        batch_y, seq_y = pad_sentence_batch(Y[start: stop], PAD)\n",
    "        fetches = [model.predicting_ids, model.cost, model.optimizer]\n",
    "        feed = {model.X: batch_x,\n",
    "                model.Y: batch_y,\n",
    "                model.X_seq_len: seq_x,\n",
    "                model.Y_seq_len: seq_y}\n",
    "        predicted, loss, _ = sess.run(fetches, feed_dict=feed)\n",
    "        total_loss += loss\n",
    "        total_accuracy += check_accuracy(predicted, batch_y)\n",
    "    # Report the per-batch averages for this epoch.\n",
    "    total_loss /= num_batches\n",
    "    total_accuracy /= num_batches\n",
    "    print('epoch: %d, avg loss: %f, avg accuracy: %f'%(epoch_idx+1, total_loss, total_accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
